import sys
from pathlib import Path

# Make the project root importable so these tests can be run from this
# directory. Path.parents[0] is the folder containing this file, so
# parents[6] is six folders above that one; it is prepended to sys.path
# only if it is not already present.
if (project_root := str(Path(__file__).resolve().parents[6])) not in sys.path:
    sys.path.insert(0, project_root)

#
# Copyright © 2024 Agora
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0, with certain conditions.
# Refer to the "LICENSE" file in the root directory for more information.
#
from pathlib import Path
import json
from unittest.mock import patch, MagicMock
import os
import asyncio
import filecmp
import shutil
import threading

from ten_runtime import (
    ExtensionTester,
    TenEnvTester,
    Data,
)
from ten_ai_base.struct import TTSTextInput, TTSFlush
from fish_audio_tts_python.fish_audio_tts import (
    EVENT_TTS_RESPONSE,
    EVENT_TTS_END,
    EVENT_TTS_FLUSH,
)


# ================ test dump file functionality ================
class ExtensionTesterDump(ExtensionTester):
    """Extension tester that records received audio for dump-file comparison.

    On start it sends a single TTS request; it collects every audio frame the
    extension emits and stops the test when ``tts_audio_end`` arrives. The
    collected audio can then be written to ``test_dump_file_path`` and
    compared against the dump file the extension produced in ``dump_dir``.
    """

    def __init__(self, dump_dir: str = "./dump/"):
        super().__init__()
        # Directory shared with the extension's dump output. Defaults to the
        # fixed './dump/' path used by test_dump_functionality.
        self.dump_dir = dump_dir
        # Use a distinct file name so the tester's own file is never mistaken
        # for the dump file generated by the extension.
        self.test_dump_file_path = os.path.join(
            self.dump_dir, "test_manual_dump.pcm"
        )
        self.audio_end_received = False
        self.received_audio_chunks: list[bytes] = []

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Called when test starts, sends a TTS request."""
        ten_env_tester.log_info("Dump test started, sending TTS request.")

        tts_input = TTSTextInput(
            request_id="tts_request_1",
            text="hello word, hello agora",
            text_input_end=True,
        )
        data = Data.create("tts_text_input")
        data.set_property_from_json(None, tts_input.model_dump_json())
        ten_env_tester.send_data(data)
        ten_env_tester.on_start_done()

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Stop the test once ``tts_audio_end`` is received; ignore other data."""
        name = data.get_name()
        if name == "tts_audio_end":
            ten_env.log_info("Received tts_audio_end, stopping test.")
            self.audio_end_received = True
            ten_env.stop_test()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Copy the bytes of each received audio frame into ``received_audio_chunks``."""
        # The frame's buffer must be locked while it is read and unlocked
        # afterwards; the bytes are copied because the underlying memory may
        # be reused once it is unlocked.
        buf = audio_frame.lock_buf()
        try:
            self.received_audio_chunks.append(bytes(buf))
        finally:
            # Always unlock, even if copying fails.
            audio_frame.unlock_buf(buf)

    def write_test_dump_file(self) -> None:
        """Write all collected audio chunks, in arrival order, to ``test_dump_file_path``."""
        # Create the directory if needed so this also works when the caller
        # has not prepared dump_dir.
        os.makedirs(self.dump_dir, exist_ok=True)
        with open(self.test_dump_file_path, "wb") as f:
            f.write(b"".join(self.received_audio_chunks))

    def find_tts_dump_file(self) -> str | None:
        """Return the path of the extension's dump file in ``dump_dir``.

        Any ``.pcm`` file other than the tester's own output file counts as
        the extension's dump. Candidates are examined in sorted order, so the
        result is deterministic if several exist. Returns None when
        ``dump_dir`` is not a directory or contains no such file.
        """
        if not os.path.isdir(self.dump_dir):
            return None
        own_name = os.path.basename(self.test_dump_file_path)
        for filename in sorted(os.listdir(self.dump_dir)):
            if filename.endswith(".pcm") and filename != own_name:
                return os.path.join(self.dump_dir, filename)
        return None


@patch("fish_audio_tts_python.extension.FishAudioTTSClient")
def test_dump_functionality(MockFishAudioTTSClient):
    """Tests that the dump file from the TTS extension matches the audio received by the test extension.

    The Fish Audio client is replaced by a mock whose ``get`` streams two
    fixed audio chunks followed by ``EVENT_TTS_END``. The extension runs with
    ``dump`` enabled. Afterwards, the audio frames collected by the tester are
    written to a file and compared byte-for-byte with the dump file the
    extension wrote to ``./dump/``. The dump directory is removed afterwards,
    whether or not the test passes.
    """
    print("Starting test_dump_functionality with mock...")

    # --- Directory Setup ---
    # Fixed './dump/' directory; ExtensionTesterDump looks for the
    # extension's dump file under './dump/' as well.
    DUMP_PATH = "./dump/"

    # Start from an empty directory so only this run's files are found.
    if os.path.exists(DUMP_PATH):
        shutil.rmtree(DUMP_PATH)
    os.makedirs(DUMP_PATH)

    try:
        # --- Mock Configuration ---
        mock_instance = MockFishAudioTTSClient.return_value
        mock_instance.clean = MagicMock()

        # Fake audio data to be streamed by the mocked client.
        fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20
        fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20

        # Async generator standing in for the TTS client's get() method.
        async def mock_get_audio_stream(text: str):
            yield (fake_audio_chunk_1, EVENT_TTS_RESPONSE)
            await asyncio.sleep(0.01)
            yield (fake_audio_chunk_2, EVENT_TTS_RESPONSE)
            await asyncio.sleep(0.01)
            yield (None, EVENT_TTS_END)

        mock_instance.get.side_effect = mock_get_audio_stream

        # --- Test Setup ---
        tester = ExtensionTesterDump()

        dump_config = {
            "dump": True,
            "dump_path": DUMP_PATH,
            "params": {
                "api_key": "test_api_key",
            },
        }

        tester.set_test_mode_single(
            "fish_audio_tts_python", json.dumps(dump_config)
        )

        print("Running dump test...")
        tester.run()
        print("Dump test completed.")

        # --- Verification ---
        # 1. Verify audio end was received
        assert tester.audio_end_received, "Expected to receive tts_audio_end"
        assert (
            len(tester.received_audio_chunks) > 0
        ), "Expected to receive audio chunks"

        # 2. Write received audio chunks to test file for comparison
        tester.write_test_dump_file()

        # 3. Find the dump file created by the extension
        tts_dump_file = tester.find_tts_dump_file()
        assert (
            tts_dump_file is not None
        ), f"Expected to find a TTS dump file in {DUMP_PATH}"
        assert os.path.exists(
            tts_dump_file
        ), f"TTS dump file should exist: {tts_dump_file}"

        # 4. Compare the files byte-for-byte
        print(
            f"Comparing test file {tester.test_dump_file_path} with TTS dump file {tts_dump_file}"
        )
        assert filecmp.cmp(
            tester.test_dump_file_path, tts_dump_file, shallow=False
        ), "Test dump file and TTS dump file should have the same content"

        print(
            f"✅ Dump functionality test passed: received {len(tester.received_audio_chunks)} audio chunks"
        )
        print(f"   Test file: {tester.test_dump_file_path}")
        print(f"   TTS dump file: {tts_dump_file}")
    finally:
        # --- Cleanup ---
        # Runs even when an assertion above fails, so no dump files are left
        # behind. Errors are ignored so they cannot mask the test's own result.
        shutil.rmtree(DUMP_PATH, ignore_errors=True)


# ================ test flush logic ================
class ExtensionTesterFlush(ExtensionTester):
    """Extension tester that flushes an in-progress TTS request.

    It sends one long TTS request and sends ``tts_flush`` as soon as the
    first audio frame arrives. It counts the audio bytes received and
    records the ``tts_audio_start`` / ``tts_flush_start`` /
    ``tts_audio_end`` / ``tts_flush_end`` events. After ``tts_flush_end`` the
    test keeps running for a short grace period, so audio that arrives late
    is detected, and then stops.
    """

    # Seconds to keep listening after tts_flush_end before stopping the test.
    FLUSH_END_GRACE_PERIOD_S = 0.5

    def __init__(self):
        super().__init__()
        self.ten_env: TenEnvTester | None = None
        self.audio_start_received = False
        self.first_audio_frame_received = False
        self.flush_start_received = False
        self.audio_end_received = False
        self.flush_end_received = False
        self.audio_end_reason = ""
        self.total_audio_duration_from_event = 0
        self.received_audio_bytes = 0
        # PCM format assumed when converting received bytes into a duration.
        # NOTE(review): must match the extension's actual output format;
        # confirm the Fish Audio extension produces 24 kHz audio.
        self.sample_rate = 24000
        self.bytes_per_sample = 2  # 16-bit samples
        self.channels = 1  # mono
        self.audio_received_after_flush_end = False
        # Pending stop timer; a repeated tts_flush_end must not schedule a
        # second stop_test() call.
        self._stop_timer: threading.Timer | None = None

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Send a long TTS request so audio is still streaming when the flush is sent."""
        self.ten_env = ten_env_tester
        ten_env_tester.log_info("Flush test started, sending long TTS request.")
        tts_input = TTSTextInput(
            request_id="tts_request_for_flush",
            text="This is a very long text designed to generate a continuous stream of audio, providing enough time to send a flush command.",
        )
        data = Data.create("tts_text_input")
        data.set_property_from_json(None, tts_input.model_dump_json())
        ten_env_tester.send_data(data)
        ten_env_tester.on_start_done()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Count audio bytes, flag late audio, and send the flush on the first frame."""
        if self.flush_end_received:
            ten_env.log_error("Received audio frame after tts_flush_end!")
            self.audio_received_after_flush_end = True

        if not self.first_audio_frame_received:
            self.first_audio_frame_received = True
            ten_env.log_info("First audio frame received, sending flush data.")
            flush_data = Data.create("tts_flush")
            flush_data.set_property_from_json(
                None,
                TTSFlush(flush_id="tts_request_for_flush").model_dump_json(),
            )
            ten_env.send_data(flush_data)

        buf = audio_frame.lock_buf()
        try:
            self.received_audio_bytes += len(buf)
        finally:
            audio_frame.unlock_buf(buf)

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Record TTS lifecycle events; schedule the test stop after ``tts_flush_end``."""
        name = data.get_name()
        ten_env.log_info(f"on_data name: {name}")

        if name == "tts_audio_start":
            self.audio_start_received = True
            return

        json_str, _ = data.get_property_to_json(None)
        if not json_str:
            return
        payload = json.loads(json_str)
        ten_env.log_info(f"on_data payload: {payload}")

        if name == "tts_flush_start":
            self.flush_start_received = True
            return

        if name == "tts_audio_end":
            self.audio_end_received = True
            self.audio_end_reason = payload.get("reason")
            # May be None if the payload lacks this field.
            self.total_audio_duration_from_event = payload.get(
                "request_total_audio_duration_ms"
            )

        elif name == "tts_flush_end":
            self.flush_end_received = True
            if self._stop_timer is not None:
                # A stop is already scheduled; never call stop_test() twice.
                return

            def stop_test_later():
                ten_env.log_info("Waited after flush_end, stopping test now.")
                ten_env.stop_test()

            # Keep running briefly so any audio arriving after tts_flush_end
            # reaches on_audio_frame and is flagged before the test stops.
            self._stop_timer = threading.Timer(
                self.FLUSH_END_GRACE_PERIOD_S, stop_test_later
            )
            # Daemon so a pending timer never keeps the interpreter alive.
            self._stop_timer.daemon = True
            self._stop_timer.start()

    def get_calculated_audio_duration_ms(self) -> int:
        """Return the duration, in whole ms, implied by the received byte count.

        Uses the assumed ``sample_rate``, ``bytes_per_sample`` and
        ``channels``. The result is truncated toward zero.
        """
        duration_sec = self.received_audio_bytes / (
            self.sample_rate * self.bytes_per_sample * self.channels
        )
        return int(duration_sec * 1000)


@patch("fish_audio_tts_python.extension.FishAudioTTSClient")
def test_flush_logic(MockFishAudioTTSClient):
    """
    Tests that sending a flush command during TTS streaming correctly stops
    the audio and sends the appropriate events.

    The mocked client streams up to 20 audio chunks. Once its ``cancel``
    mock has been called, it yields ``EVENT_TTS_FLUSH`` and stops. The test
    checks the following:
    - the expected events arrive;
    - no audio follows ``tts_flush_end``;
    - the duration reported in ``tts_audio_end`` matches the audio actually
      received, within a small tolerance.
    """
    print("Starting test_flush_logic with mock...")

    # Allowed difference between reported and byte-derived durations.
    duration_tolerance_ms = 10

    mock_instance = MockFishAudioTTSClient.return_value
    mock_instance.clean = MagicMock()

    async def mock_get_long_audio_stream(text: str):
        for _ in range(20):
            # In a real scenario, the cancel() call would set a flag.
            # We simulate this by checking the mock's 'called' status.
            if mock_instance.cancel.called:
                print("Mock detected cancel call, sending EVENT_TTS_FLUSH.")
                yield (None, EVENT_TTS_FLUSH)
                return  # Stop the generator immediately after flush
            yield (b"\x11\x22\x33" * 100, EVENT_TTS_RESPONSE)
            await asyncio.sleep(0.1)

        # This part is only reached if not cancelled - normal completion
        yield (None, EVENT_TTS_END)

    mock_instance.get.side_effect = mock_get_long_audio_stream

    config = {
        "params": {
            "api_key": "test_api_key",
        },
    }
    tester = ExtensionTesterFlush()
    tester.set_test_mode_single("fish_audio_tts_python", json.dumps(config))

    print("Running flush logic test...")
    tester.run()
    print("Flush logic test completed.")

    assert tester.audio_start_received, "Did not receive tts_audio_start."
    assert tester.first_audio_frame_received, "Did not receive any audio frame."
    assert tester.audio_end_received, "Did not receive tts_audio_end."
    assert tester.flush_end_received, "Did not receive tts_flush_end."
    assert (
        not tester.audio_received_after_flush_end
    ), "Received audio after tts_flush_end."

    calculated_duration = tester.get_calculated_audio_duration_ms()
    event_duration = tester.total_audio_duration_from_event
    print(
        f"calculated_duration: {calculated_duration}, event_duration: {event_duration}"
    )
    # The tester reads this field with payload.get(), so it is None when the
    # tts_audio_end payload omits it. Fail clearly instead of raising a
    # TypeError in the subtraction below.
    assert (
        event_duration is not None
    ), "tts_audio_end did not report request_total_audio_duration_ms."
    assert (
        abs(calculated_duration - event_duration) < duration_tolerance_ms
    ), f"Mismatch in audio duration. Calculated: {calculated_duration}ms, From event: {event_duration}ms"

    print("✅ Flush logic test passed successfully.")
